
Conversation

LU-JOHN (Contributor) commented on May 1, 2025

Convert a vector 64-bit lshr to 32-bit if the shift amount is known to be >= 32.
Also convert a scalar 64-bit lshr to 32-bit if the shift amount is variable but known to be >= 32.
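
As a minimal illustration (mirroring the new srl64_reduce.ll test added below; @example, %x, and %p are made-up names), range metadata is one way the shift amount can be proven to be >= 32:

  define i64 @example(i64 %x, ptr %p) {
    %amt = load i64, ptr %p, !range !0   ; shift amount known to be in [32, 64)
    %res = lshr i64 %x, %amt             ; 64-bit shift
    ret i64 %res
  }
  !0 = !{i64 32, i64 64}

Since the shift amount is at least 32, the low 32 bits of the result come from a 32-bit shift of the high half of %x (by %amt & 31) and the high 32 bits are simply zero, so a single v_lshrrev_b32 plus a zero move can replace the 64-bit v_lshrrev_b64.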

llvmbot (Member) commented on May 1, 2025

@llvm/pr-subscribers-backend-amdgpu

Author: None (LU-JOHN)

Changes

Convert a vector 64-bit lshr to 32-bit if the shift amount is known to be >= 32.
Also convert a scalar 64-bit lshr to 32-bit if the shift amount is variable but known to be >= 32.


Patch is 34.17 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138204.diff

3 Files Affected:

  • (modified) llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp (+96-37)
  • (modified) llvm/test/CodeGen/AMDGPU/mad_64_32.ll (+30-47)
  • (added) llvm/test/CodeGen/AMDGPU/srl64_reduce.ll (+561)
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index 236c373e70250..304b3cbf7cf2e 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -4176,50 +4176,110 @@ SDValue AMDGPUTargetLowering::performSraCombine(SDNode *N,
 
 SDValue AMDGPUTargetLowering::performSrlCombine(SDNode *N,
                                                 DAGCombinerInfo &DCI) const {
-  auto *RHS = dyn_cast<ConstantSDNode>(N->getOperand(1));
-  if (!RHS)
-    return SDValue();
-
+  SDValue RHS = N->getOperand(1);
+  ConstantSDNode *CRHS = dyn_cast<ConstantSDNode>(RHS);
   EVT VT = N->getValueType(0);
   SDValue LHS = N->getOperand(0);
-  unsigned ShiftAmt = RHS->getZExtValue();
   SelectionDAG &DAG = DCI.DAG;
   SDLoc SL(N);
+  unsigned RHSVal;
+
+  if (CRHS) {
+    RHSVal = CRHS->getZExtValue();
 
-  // fold (srl (and x, c1 << c2), c2) -> (and (srl(x, c2), c1)
-  // this improves the ability to match BFE patterns in isel.
-  if (LHS.getOpcode() == ISD::AND) {
-    if (auto *Mask = dyn_cast<ConstantSDNode>(LHS.getOperand(1))) {
-      unsigned MaskIdx, MaskLen;
-      if (Mask->getAPIntValue().isShiftedMask(MaskIdx, MaskLen) &&
-          MaskIdx == ShiftAmt) {
-        return DAG.getNode(
-            ISD::AND, SL, VT,
-            DAG.getNode(ISD::SRL, SL, VT, LHS.getOperand(0), N->getOperand(1)),
-            DAG.getNode(ISD::SRL, SL, VT, LHS.getOperand(1), N->getOperand(1)));
+    // fold (srl (and x, c1 << c2), c2) -> (and (srl(x, c2), c1)
+    // this improves the ability to match BFE patterns in isel.
+    if (LHS.getOpcode() == ISD::AND) {
+      if (auto *Mask = dyn_cast<ConstantSDNode>(LHS.getOperand(1))) {
+        unsigned MaskIdx, MaskLen;
+        if (Mask->getAPIntValue().isShiftedMask(MaskIdx, MaskLen) &&
+            MaskIdx == RHSVal) {
+          return DAG.getNode(ISD::AND, SL, VT,
+                             DAG.getNode(ISD::SRL, SL, VT, LHS.getOperand(0),
+                                         N->getOperand(1)),
+                             DAG.getNode(ISD::SRL, SL, VT, LHS.getOperand(1),
+                                         N->getOperand(1)));
+        }
       }
     }
   }
 
-  if (VT != MVT::i64)
+  // If the shift is exact, the shifted out bits matter.
+  if (N->getFlags().hasExact())
     return SDValue();
 
-  if (ShiftAmt < 32)
+  if (VT.getScalarType() != MVT::i64)
     return SDValue();
 
-  // srl i64:x, C for C >= 32
-  // =>
-  //   build_pair (srl hi_32(x), C - 32), 0
-  SDValue Zero = DAG.getConstant(0, SL, MVT::i32);
+  // for C >= 32
+  // i64 (srl x, C) -> (build_pair (srl hi_32(x), C -32), 0)
 
-  SDValue Hi = getHiHalf64(LHS, DAG);
+  // On some subtargets, 64-bit shift is a quarter rate instruction. In the
+  // common case, splitting this into a move and a 32-bit shift is faster and
+  // the same code size.
+  KnownBits Known = DAG.computeKnownBits(RHS);
 
-  SDValue NewConst = DAG.getConstant(ShiftAmt - 32, SL, MVT::i32);
-  SDValue NewShift = DAG.getNode(ISD::SRL, SL, MVT::i32, Hi, NewConst);
+  EVT ElementType = VT.getScalarType();
+  EVT TargetScalarType = ElementType.getHalfSizedIntegerVT(*DAG.getContext());
+  EVT TargetType = VT.isVector() ? VT.changeVectorElementType(TargetScalarType)
+                                 : TargetScalarType;
 
-  SDValue BuildPair = DAG.getBuildVector(MVT::v2i32, SL, {NewShift, Zero});
+  if (Known.getMinValue().getZExtValue() < TargetScalarType.getSizeInBits())
+    return SDValue();
 
-  return DAG.getNode(ISD::BITCAST, SL, MVT::i64, BuildPair);
+  SDValue ShiftAmt;
+  if (CRHS) {
+    ShiftAmt = DAG.getConstant(RHSVal - TargetScalarType.getSizeInBits(), SL,
+                               TargetType);
+  } else {
+    SDValue truncShiftAmt = DAG.getNode(ISD::TRUNCATE, SL, TargetType, RHS);
+    const SDValue ShiftMask =
+        DAG.getConstant(TargetScalarType.getSizeInBits() - 1, SL, TargetType);
+    // This AND instruction will clamp out of bounds shift values.
+    // It will also be removed during later instruction selection.
+    ShiftAmt = DAG.getNode(ISD::AND, SL, TargetType, truncShiftAmt, ShiftMask);
+  }
+
+  const SDValue Zero = DAG.getConstant(0, SL, TargetScalarType);
+  EVT ConcatType;
+  SDValue Hi;
+  SDLoc LHSSL(LHS);
+  // Bitcast LHS into ConcatType so hi-half of source can be extracted into Hi
+  if (VT.isVector()) {
+    unsigned NElts = TargetType.getVectorNumElements();
+    ConcatType = TargetType.getDoubleNumVectorElementsVT(*DAG.getContext());
+    SDValue SplitLHS = DAG.getNode(ISD::BITCAST, LHSSL, ConcatType, LHS);
+    SmallVector<SDValue, 8> HiOps(NElts);
+    SmallVector<SDValue, 16> HiAndLoOps;
+
+    DAG.ExtractVectorElements(SplitLHS, HiAndLoOps, 0, NElts * 2);
+    for (unsigned I = 0; I != NElts; ++I) {
+      HiOps[I] = HiAndLoOps[2 * I + 1];
+    }
+    Hi = DAG.getNode(ISD::BUILD_VECTOR, LHSSL, TargetType, HiOps);
+  } else {
+    const SDValue One = DAG.getConstant(1, LHSSL, TargetScalarType);
+    ConcatType = EVT::getVectorVT(*DAG.getContext(), TargetType, 2);
+    SDValue SplitLHS = DAG.getNode(ISD::BITCAST, LHSSL, ConcatType, LHS);
+    Hi = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, LHSSL, TargetType, SplitLHS, One);
+  }
+
+  SDValue NewShift = DAG.getNode(ISD::SRL, SL, TargetType, Hi, ShiftAmt);
+
+  SDValue Vec;
+  if (VT.isVector()) {
+    unsigned NElts = TargetType.getVectorNumElements();
+    SmallVector<SDValue, 8> LoOps;
+    SmallVector<SDValue, 16> HiAndLoOps(NElts * 2, Zero);
+
+    DAG.ExtractVectorElements(NewShift, LoOps, 0, NElts);
+    for (unsigned I = 0; I != NElts; ++I)
+      HiAndLoOps[2 * I] = LoOps[I];
+    Vec = DAG.getNode(ISD::BUILD_VECTOR, SL, ConcatType, HiAndLoOps);
+  } else {
+    Vec = DAG.getBuildVector(ConcatType, SL, {NewShift, Zero});
+  }
+  return DAG.getNode(ISD::BITCAST, SL, VT, Vec);
 }
 
 SDValue AMDGPUTargetLowering::performTruncateCombine(
@@ -5198,22 +5258,21 @@ SDValue AMDGPUTargetLowering::PerformDAGCombine(SDNode *N,
 
     break;
   }
-  case ISD::SHL: {
+  case ISD::SHL:
+  case ISD::SRL: {
     // Range metadata can be invalidated when loads are converted to legal types
     // (e.g. v2i64 -> v4i32).
-    // Try to convert vector shl before type legalization so that range metadata
-    // can be utilized.
+    // Try to convert vector shl/srl before type legalization so that range
+    // metadata can be utilized.
     if (!(N->getValueType(0).isVector() &&
           DCI.getDAGCombineLevel() == BeforeLegalizeTypes) &&
         DCI.getDAGCombineLevel() < AfterLegalizeDAG)
       break;
-    return performShlCombine(N, DCI);
-  }
-  case ISD::SRL: {
-    if (DCI.getDAGCombineLevel() < AfterLegalizeDAG)
-      break;
-
-    return performSrlCombine(N, DCI);
+    if (N->getOpcode() == ISD::SHL) {
+      return performShlCombine(N, DCI);
+    } else {
+      return performSrlCombine(N, DCI);
+    }
   }
   case ISD::SRA: {
     if (DCI.getDAGCombineLevel() < AfterLegalizeDAG)
diff --git a/llvm/test/CodeGen/AMDGPU/mad_64_32.ll b/llvm/test/CodeGen/AMDGPU/mad_64_32.ll
index c5c95380fde9b..aa8c3ead474e0 100644
--- a/llvm/test/CodeGen/AMDGPU/mad_64_32.ll
+++ b/llvm/test/CodeGen/AMDGPU/mad_64_32.ll
@@ -1947,16 +1947,14 @@ define <2 x i64> @lshr_mad_i64_vec(<2 x i64> %arg0) #0 {
 ; CI-LABEL: lshr_mad_i64_vec:
 ; CI:       ; %bb.0:
 ; CI-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; CI-NEXT:    v_mov_b32_e32 v6, v3
-; CI-NEXT:    v_mov_b32_e32 v3, v1
-; CI-NEXT:    v_mov_b32_e32 v1, 0
 ; CI-NEXT:    s_mov_b32 s4, 0xffff1c18
-; CI-NEXT:    v_mad_u64_u32 v[4:5], s[4:5], v3, s4, v[0:1]
-; CI-NEXT:    v_mov_b32_e32 v3, v1
+; CI-NEXT:    v_mad_u64_u32 v[4:5], s[4:5], v1, s4, v[0:1]
 ; CI-NEXT:    s_mov_b32 s4, 0xffff1118
-; CI-NEXT:    v_mad_u64_u32 v[2:3], s[4:5], v6, s4, v[2:3]
+; CI-NEXT:    v_mad_u64_u32 v[6:7], s[4:5], v3, s4, v[2:3]
+; CI-NEXT:    v_sub_i32_e32 v1, vcc, v5, v1
+; CI-NEXT:    v_sub_i32_e32 v3, vcc, v7, v3
 ; CI-NEXT:    v_mov_b32_e32 v0, v4
-; CI-NEXT:    v_mov_b32_e32 v1, v5
+; CI-NEXT:    v_mov_b32_e32 v2, v6
 ; CI-NEXT:    s_setpc_b64 s[30:31]
 ;
 ; SI-LABEL: lshr_mad_i64_vec:
@@ -1979,44 +1977,28 @@ define <2 x i64> @lshr_mad_i64_vec(<2 x i64> %arg0) #0 {
 ; GFX9-LABEL: lshr_mad_i64_vec:
 ; GFX9:       ; %bb.0:
 ; GFX9-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; GFX9-NEXT:    v_mov_b32_e32 v6, v3
-; GFX9-NEXT:    v_mov_b32_e32 v3, v1
-; GFX9-NEXT:    v_mov_b32_e32 v1, 0
 ; GFX9-NEXT:    s_mov_b32 s4, 0xffff1c18
-; GFX9-NEXT:    v_mad_u64_u32 v[4:5], s[4:5], v3, s4, v[0:1]
-; GFX9-NEXT:    v_mov_b32_e32 v3, v1
+; GFX9-NEXT:    v_mad_u64_u32 v[4:5], s[4:5], v1, s4, v[0:1]
 ; GFX9-NEXT:    s_mov_b32 s4, 0xffff1118
-; GFX9-NEXT:    v_mad_u64_u32 v[2:3], s[4:5], v6, s4, v[2:3]
+; GFX9-NEXT:    v_mad_u64_u32 v[6:7], s[4:5], v3, s4, v[2:3]
+; GFX9-NEXT:    v_sub_u32_e32 v1, v5, v1
+; GFX9-NEXT:    v_sub_u32_e32 v3, v7, v3
 ; GFX9-NEXT:    v_mov_b32_e32 v0, v4
-; GFX9-NEXT:    v_mov_b32_e32 v1, v5
+; GFX9-NEXT:    v_mov_b32_e32 v2, v6
 ; GFX9-NEXT:    s_setpc_b64 s[30:31]
 ;
-; GFX1100-LABEL: lshr_mad_i64_vec:
-; GFX1100:       ; %bb.0:
-; GFX1100-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; GFX1100-NEXT:    v_mov_b32_e32 v8, v3
-; GFX1100-NEXT:    v_dual_mov_b32 v6, v1 :: v_dual_mov_b32 v1, 0
-; GFX1100-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
-; GFX1100-NEXT:    v_mad_u64_u32 v[4:5], null, 0xffff1c18, v6, v[0:1]
-; GFX1100-NEXT:    v_dual_mov_b32 v3, v1 :: v_dual_mov_b32 v0, v4
-; GFX1100-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
-; GFX1100-NEXT:    v_mad_u64_u32 v[6:7], null, 0xffff1118, v8, v[2:3]
-; GFX1100-NEXT:    v_dual_mov_b32 v1, v5 :: v_dual_mov_b32 v2, v6
-; GFX1100-NEXT:    s_delay_alu instid0(VALU_DEP_2)
-; GFX1100-NEXT:    v_mov_b32_e32 v3, v7
-; GFX1100-NEXT:    s_setpc_b64 s[30:31]
-;
-; GFX1150-LABEL: lshr_mad_i64_vec:
-; GFX1150:       ; %bb.0:
-; GFX1150-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; GFX1150-NEXT:    v_dual_mov_b32 v4, v3 :: v_dual_mov_b32 v5, v1
-; GFX1150-NEXT:    v_mov_b32_e32 v1, 0
-; GFX1150-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_3)
-; GFX1150-NEXT:    v_mov_b32_e32 v3, v1
-; GFX1150-NEXT:    v_mad_u64_u32 v[0:1], null, 0xffff1c18, v5, v[0:1]
-; GFX1150-NEXT:    s_delay_alu instid0(VALU_DEP_2)
-; GFX1150-NEXT:    v_mad_u64_u32 v[2:3], null, 0xffff1118, v4, v[2:3]
-; GFX1150-NEXT:    s_setpc_b64 s[30:31]
+; GFX11-LABEL: lshr_mad_i64_vec:
+; GFX11:       ; %bb.0:
+; GFX11-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX11-NEXT:    v_mad_u64_u32 v[4:5], null, 0xffff1c18, v1, v[0:1]
+; GFX11-NEXT:    v_mad_u64_u32 v[6:7], null, 0xffff1118, v3, v[2:3]
+; GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_2) | instskip(NEXT) | instid1(VALU_DEP_3)
+; GFX11-NEXT:    v_sub_nc_u32_e32 v1, v5, v1
+; GFX11-NEXT:    v_mov_b32_e32 v0, v4
+; GFX11-NEXT:    s_delay_alu instid0(VALU_DEP_3) | instskip(NEXT) | instid1(VALU_DEP_4)
+; GFX11-NEXT:    v_sub_nc_u32_e32 v3, v7, v3
+; GFX11-NEXT:    v_mov_b32_e32 v2, v6
+; GFX11-NEXT:    s_setpc_b64 s[30:31]
 ;
 ; GFX12-LABEL: lshr_mad_i64_vec:
 ; GFX12:       ; %bb.0:
@@ -2025,13 +2007,14 @@ define <2 x i64> @lshr_mad_i64_vec(<2 x i64> %arg0) #0 {
 ; GFX12-NEXT:    s_wait_samplecnt 0x0
 ; GFX12-NEXT:    s_wait_bvhcnt 0x0
 ; GFX12-NEXT:    s_wait_kmcnt 0x0
-; GFX12-NEXT:    v_dual_mov_b32 v4, v3 :: v_dual_mov_b32 v5, v1
-; GFX12-NEXT:    v_mov_b32_e32 v1, 0
-; GFX12-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_3)
-; GFX12-NEXT:    v_mov_b32_e32 v3, v1
-; GFX12-NEXT:    v_mad_co_u64_u32 v[0:1], null, 0xffff1c18, v5, v[0:1]
-; GFX12-NEXT:    s_delay_alu instid0(VALU_DEP_2)
-; GFX12-NEXT:    v_mad_co_u64_u32 v[2:3], null, 0xffff1118, v4, v[2:3]
+; GFX12-NEXT:    v_mad_co_u64_u32 v[4:5], null, 0xffff1c18, v1, v[0:1]
+; GFX12-NEXT:    v_mad_co_u64_u32 v[6:7], null, 0xffff1118, v3, v[2:3]
+; GFX12-NEXT:    s_delay_alu instid0(VALU_DEP_2) | instskip(NEXT) | instid1(VALU_DEP_3)
+; GFX12-NEXT:    v_sub_nc_u32_e32 v1, v5, v1
+; GFX12-NEXT:    v_mov_b32_e32 v0, v4
+; GFX12-NEXT:    s_delay_alu instid0(VALU_DEP_3) | instskip(NEXT) | instid1(VALU_DEP_4)
+; GFX12-NEXT:    v_sub_nc_u32_e32 v3, v7, v3
+; GFX12-NEXT:    v_mov_b32_e32 v2, v6
 ; GFX12-NEXT:    s_setpc_b64 s[30:31]
   %lsh = lshr <2 x i64> %arg0, <i64 32, i64 32>
   %mul = mul <2 x i64> %lsh, <i64 s0xffffffffffff1c18, i64 s0xffffffffffff1118>
diff --git a/llvm/test/CodeGen/AMDGPU/srl64_reduce.ll b/llvm/test/CodeGen/AMDGPU/srl64_reduce.ll
new file mode 100644
index 0000000000000..eba894bfad57c
--- /dev/null
+++ b/llvm/test/CodeGen/AMDGPU/srl64_reduce.ll
@@ -0,0 +1,561 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+;; Test reduction of:
+;;
+;;   DST = lshr i64 X, Y
+;;
+;; where Y is in the range [63-32] to:
+;;
+;;   DST = [srl i32 X, (Y & 0x1F), 0]
+
+; RUN: llc -mtriple=amdgcn-amd-amdhsa -mcpu=gfx900 < %s | FileCheck %s
+
+;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+; Test range with metadata
+;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+
+define i64 @srl_metadata(i64 %arg0, ptr %arg1.ptr) {
+; CHECK-LABEL: srl_metadata:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    flat_load_dword v0, v[2:3]
+; CHECK-NEXT:    s_waitcnt vmcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_lshrrev_b32_e32 v0, v0, v1
+; CHECK-NEXT:    v_mov_b32_e32 v1, 0
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %shift.amt = load i64, ptr %arg1.ptr, !range !0, !noundef !{}
+  %srl = lshr i64 %arg0, %shift.amt
+  ret i64 %srl
+}
+
+; Shifted bits matter for exact shift.  Reduction must not be done.
+define i64 @srl_exact_metadata(i64 %arg0, ptr %arg1.ptr) {
+; CHECK-LABEL: srl_exact_metadata:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    flat_load_dword v2, v[2:3]
+; CHECK-NEXT:    s_waitcnt vmcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_lshrrev_b64 v[0:1], v2, v[0:1]
+; CHECK-NEXT:    v_mov_b32_e32 v1, 0
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %shift.amt = load i64, ptr %arg1.ptr, !range !0, !noundef !{}
+  %srl = lshr exact i64 %arg0, %shift.amt
+  ret i64 %srl
+}
+
+define i64 @srl_metadata_two_ranges(i64 %arg0, ptr %arg1.ptr) {
+; CHECK-LABEL: srl_metadata_two_ranges:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    flat_load_dword v0, v[2:3]
+; CHECK-NEXT:    s_waitcnt vmcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_lshrrev_b32_e32 v0, v0, v1
+; CHECK-NEXT:    v_mov_b32_e32 v1, 0
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %shift.amt = load i64, ptr %arg1.ptr, !range !1, !noundef !{}
+  %srl = lshr i64 %arg0, %shift.amt
+  ret i64 %srl
+}
+
+; Known minimum is too low.  Reduction must not be done.
+define i64 @srl_metadata_out_of_range(i64 %arg0, ptr %arg1.ptr) {
+; CHECK-LABEL: srl_metadata_out_of_range:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    flat_load_dword v2, v[2:3]
+; CHECK-NEXT:    s_waitcnt vmcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_lshrrev_b64 v[0:1], v2, v[0:1]
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %shift.amt = load i64, ptr %arg1.ptr, !range !2, !noundef !{}
+  %srl = lshr i64 %arg0, %shift.amt
+  ret i64 %srl
+}
+
+; Bounds cannot be truncated to i32 when load is narrowed to i32.
+; Reduction must not be done.
+; Bounds were chosen so that if bounds were truncated to i32 the
+; known minimum would be 32 and the srl would be erroneously reduced.
+define i64 @srl_metadata_cant_be_narrowed_to_i32(i64 %arg0, ptr %arg1.ptr) {
+; CHECK-LABEL: srl_metadata_cant_be_narrowed_to_i32:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    flat_load_dword v2, v[2:3]
+; CHECK-NEXT:    s_waitcnt vmcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_lshrrev_b64 v[0:1], v2, v[0:1]
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %shift.amt = load i64, ptr %arg1.ptr, !range !3, !noundef !{}
+  %srl = lshr i64 %arg0, %shift.amt
+  ret i64 %srl
+}
+
+define <2 x i64> @srl_v2_metadata(<2 x i64> %arg0, ptr %arg1.ptr) {
+; CHECK-LABEL: srl_v2_metadata:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    flat_load_dwordx4 v[4:7], v[4:5]
+; CHECK-NEXT:    s_waitcnt vmcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_lshrrev_b32_e32 v0, v4, v1
+; CHECK-NEXT:    v_lshrrev_b32_e32 v2, v6, v3
+; CHECK-NEXT:    v_mov_b32_e32 v1, 0
+; CHECK-NEXT:    v_mov_b32_e32 v3, 0
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %shift.amt = load <2 x i64>, ptr %arg1.ptr, !range !0, !noundef !{}
+  %srl = lshr <2 x i64> %arg0, %shift.amt
+  ret <2 x i64> %srl
+}
+
+; Shifted bits matter for exact shift.  Reduction must not be done.
+define <2 x i64> @srl_exact_v2_metadata(<2 x i64> %arg0, ptr %arg1.ptr) {
+; CHECK-LABEL: srl_exact_v2_metadata:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    flat_load_dwordx4 v[4:7], v[4:5]
+; CHECK-NEXT:    s_waitcnt vmcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_lshrrev_b64 v[0:1], v4, v[0:1]
+; CHECK-NEXT:    v_lshrrev_b64 v[2:3], v6, v[2:3]
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %shift.amt = load <2 x i64>, ptr %arg1.ptr, !range !0, !noundef !{}
+  %srl = lshr exact <2 x i64> %arg0, %shift.amt
+  ret <2 x i64> %srl
+}
+
+define <3 x i64> @srl_v3_metadata(<3 x i64> %arg0, ptr %arg1.ptr) {
+; CHECK-LABEL: srl_v3_metadata:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    flat_load_dword v0, v[6:7] offset:16
+; CHECK-NEXT:    flat_load_dwordx4 v[8:11], v[6:7]
+; CHECK-NEXT:    s_waitcnt vmcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_lshrrev_b32_e32 v4, v0, v5
+; CHECK-NEXT:    v_lshrrev_b32_e32 v0, v8, v1
+; CHECK-NEXT:    v_lshrrev_b32_e32 v2, v10, v3
+; CHECK-NEXT:    v_mov_b32_e32 v1, 0
+; CHECK-NEXT:    v_mov_b32_e32 v3, 0
+; CHECK-NEXT:    v_mov_b32_e32 v5, 0
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %shift.amt = load <3 x i64>, ptr %arg1.ptr, !range !0, !noundef !{}
+  %srl = lshr <3 x i64> %arg0, %shift.amt
+  ret <3 x i64> %srl
+}
+
+define <4 x i64> @srl_v4_metadata(<4 x i64> %arg0, ptr %arg1.ptr) {
+; CHECK-LABEL: srl_v4_metadata:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    flat_load_dwordx4 v[10:13], v[8:9]
+; CHECK-NEXT:    s_waitcnt vmcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    flat_load_dwordx4 v[13:16], v[8:9] offset:16
+; CHECK-NEXT:    ; kill: killed $vgpr8 killed $vgpr9
+; CHECK-NEXT:    v_lshrrev_b32_e32 v0, v10, v1
+; CHECK-NEXT:    v_lshrrev_b32_e32 v2, v12, v3
+; CHECK-NEXT:    s_waitcnt vmcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_lshrrev_b32_e32 v4, v13, v5
+; CHECK-NEXT:    v_lshrrev_b32_e32 v6, v15, v7
+; CHECK-NEXT:    v_mov_b32_e32 v1, 0
+; CHECK-NEXT:    v_mov_b32_e32 v3, 0
+; CHECK-NEXT:    v_mov_b32_e32 v5, 0
+; CHECK-NEXT:    v_mov_b32_e32 v7, 0
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %shift.amt = load <4 x i64>, ptr %arg1.ptr, !range !0, !noundef !{}
+  %srl = lshr <4 x i64> %arg0, %shift.amt
+  ret <4 x i64> %srl
+}
+
+!0 = !{i64 32, i64 64}
+!1 = !{i64 32, i64 38, i64 42, i64 48}
+!2 = !{i64 31, i64 38, i64 42, i64 48}
+!3 = !{i64 32, i64 38, i64 2147483680, i64 2147483681}
+
+;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+; Test range with an "or X, 16"
+;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+
+; These cases must not be reduced because the known minimum, 16, is not in range.
+
+define i64 @srl_or16(i64 %arg0, i64 %shift_amt) {
+; CHECK-LABEL: srl_or16:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_or_b32_e32 v2, 16, v2
+; CHECK-NEXT:    v_lshrrev_b64 v[0:1], v2, v[0:1]
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %or = or i64 %shift_amt, 16
+  %srl = lshr i64 %arg0, %or
+  ret i64 %srl
+}
+
+define <2 x i64> @srl_v2_or16(<2 x i64> %arg0, <2 x i64> %shift_amt) {
+; CHECK-LABEL: srl_v2_or16:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_or_b32_e32 v5, 16, v6
+; CHECK-NEXT:    v_or_b32_e32 v4, 16, v4
+; CHECK-NEXT:    v_lshrrev_b64 v[0:1], v4, v[0:1]
+; CHECK-NEXT:    v_lshrrev_b64 v[2:3], v5, v[2:3]
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %or = or <2 x i64> %shift_amt, splat (i64 16)
+  %srl = lshr <2 x i64> %arg0, %or
+  ret <2 x i...
[truncated]

Comment on lines 4243 to 4266
Contributor

This looks suspiciously long, should share code with the other shift case?

Contributor Author

I could not come up with a simpler way to code this. In the SHL case, extracting the low half can be done with a single TRUNCATE instruction. EXTRACT_ELEMENT works for the high half, but it does not work for vectors.
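
For context, a rough SelectionDAG sketch of the asymmetry (not the patch itself; SL and LHS stand in for the values in performSrlCombine):

  // SHL combine: the low half of a 64-bit value is just a truncate.
  SDValue Lo = DAG.getNode(ISD::TRUNCATE, SL, MVT::i32, LHS);

  // Scalar SRL: the high half can be taken with EXTRACT_ELEMENT.
  SDValue Hi = DAG.getNode(ISD::EXTRACT_ELEMENT, SL, MVT::i32, LHS,
                           DAG.getConstant(1, SL, MVT::i32));

  // Vector SRL: EXTRACT_ELEMENT has no vector form, so the patch bitcasts
  // (e.g. v2i64 -> v4i32) and rebuilds a vector from the odd (high) lanes.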

LU-JOHN requested a review from arsenm on May 3, 2025
LU-JOHN marked this pull request as draft on May 3, 2025
LU-JOHN marked this pull request as ready for review on May 6, 2025
if (CRHS) {
RHSVal = CRHS->getZExtValue();

// fold (srl (and x, c1 << c2), c2) -> (and (srl(x, c2), c1)
Contributor

Can we do this in the generic DAG combiner? This doesn't seem to be target dependent.


DAG.ExtractVectorElements(SplitLHS, HiAndLoOps, 0, NElts * 2);
for (unsigned I = 0; I != NElts; ++I) {
HiOps[I] = HiAndLoOps[2 * I + 1];
Contributor

I think you can simply use insert.

Contributor Author

I feel array indexing is clearer than using insert. For reference, see the last commit of #132964, which switched from insert to array indexing to address feedback.

return DAG.getNode(ISD::BITCAST, SL, MVT::i64, BuildPair);
DAG.ExtractVectorElements(NewShift, LoOps, 0, NElts);
for (unsigned I = 0; I != NElts; ++I)
HiAndLoOps[2 * I] = LoOps[I];
Contributor

similarly, insert

Signed-off-by: John Lu <[email protected]>
LU-JOHN requested a review from shiltian on June 9, 2025
shiltian (Contributor) left a comment

LGTM with one nit

Signed-off-by: John Lu <[email protected]>
arsenm (Contributor) left a comment

Probably should turn this vector repacking into a utility function later
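
For instance, the repacking in this patch could later be factored into a helper along these lines (purely illustrative; the name and signature are made up, and the body is lifted from the vector path in performSrlCombine):

  // Hypothetical helper: given a value whose elements are twice as wide as
  // TargetType's, return a vector built from the high (odd) half-lanes.
  static SDValue extractHiHalves(SelectionDAG &DAG, const SDLoc &SL, SDValue V,
                                 EVT TargetType, EVT ConcatType) {
    unsigned NElts = TargetType.getVectorNumElements();
    SDValue Split = DAG.getNode(ISD::BITCAST, SL, ConcatType, V);
    SmallVector<SDValue, 16> HiAndLoOps;
    DAG.ExtractVectorElements(Split, HiAndLoOps, 0, NElts * 2);
    SmallVector<SDValue, 8> HiOps(NElts);
    for (unsigned I = 0; I != NElts; ++I)
      HiOps[I] = HiAndLoOps[2 * I + 1];
    return DAG.getNode(ISD::BUILD_VECTOR, SL, TargetType, HiOps);
  }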

arsenm merged commit c4caf00 into llvm:main on Jun 13, 2025
7 checks passed
tomtor pushed a commit to tomtor/llvm-project that referenced this pull request on Jun 14, 2025
akuhlens pushed a commit to akuhlens/llvm-project that referenced this pull request on Jun 24, 2025